import argparse
import json
import os.path
import re
from json import JSONDecodeError


class TurboEdge:
    """A directed edge between two nodes of the Turbolizer graph."""

    def __init__(self, source: 'TurboNode', target: 'TurboNode', index: int, edge_type: str):
        """Store both endpoints, the edge's index and its type."""
        self.source, self.target = source, target
        self.index = index
        self.type = edge_type

    def to_turbo(self) -> dict:
        """Return the edge as a plain dict; endpoints are replaced by their ``id``."""
        return dict(
            source=self.source.id,
            target=self.target.id,
            index=self.index,
            type=self.type,
        )


class TurboNode:
    """A circuit IR node converted to the node format written for the Turbolizer.

    Besides the fields exported by ``to_turbo``, a node keeps the ids of its
    inputs grouped by kind (``state_inputs``, ``depend_inputs``,
    ``value_inputs``, ``frame_state_inputs``, ``root_inputs``), its ``outputs``
    and the resolved ``input_edges``.
    """

    # Opcodes reported with ``control=True``. Kept as a tuple so that the
    # membership test uses ``==`` and therefore also accepts unhashable values.
    _CONTROL_OPCODES = (
        'DEPEND_RELAY',
        'IF_BRANCH',
        'IF_TRUE',
        'IF_FALSE',
        'MERGE',
        'DEPEND_SELECTOR',
        'RETURN',
        'SWITCH_CASE',
        'DEOPT_CHECK',
        'SWITCH_BRANCH',
        'CIRCUIT_ROOT',
        'LOOP_BEGIN',
        'LOOP_BACK',
        'JS_BYTECODE',
        'EffectPhi',
    )

    # Attribute names in the positional order of a circuit IR node's 'in' list.
    _INPUT_ATTRS = (
        'state_inputs',
        'depend_inputs',
        'value_inputs',
        'frame_state_inputs',
        'root_inputs',
    )

    # (attribute name, edge type) pairs in the order edges are created.
    # The order matters: it determines the ``index`` of every edge.
    _EDGE_GROUPS = (
        ('value_inputs', 'value'),
        ('frame_state_inputs', 'frame-state'),
        ('depend_inputs', 'effect'),
        ('state_inputs', 'control'),
        ('root_inputs', 'root'),
    )

    def __init__(self, node_id: int, label: str, title: str, opcode: str, control: bool,
                 opinfo: str = "", live: bool = True, properties: str = "") -> None:
        """Store the exported fields and start with empty input/output lists."""
        self.id = node_id
        self.label = label
        self.title = title
        self.opcode = opcode
        self.control = control
        self.opinfo = opinfo
        self.live = live
        self.properties = properties

        self.input_edges = []
        self.state_inputs = []
        self.depend_inputs = []
        self.value_inputs = []
        self.frame_state_inputs = []
        self.root_inputs = []
        self.outputs = []

    def create_input_edge(self, turbo_nodes_map: dict) -> None:
        """Create the input edges of this node from its grouped input ids.

        Edges are created group by group in the order value, frame-state,
        effect (``depend_inputs``), control (``state_inputs``), root.

        :param turbo_nodes_map: mapping from node id to ``TurboNode`` used to
            resolve the source of every edge.
        """
        for attr_name, edge_type in self._EDGE_GROUPS:
            for source_id in getattr(self, attr_name):
                self.add_input(source_id, edge_type=edge_type, turbo_nodes_map=turbo_nodes_map)

    def add_input(self, source_id: int, edge_type: str, turbo_nodes_map: dict) -> None:
        """Append an edge from node ``source_id`` to this node.

        If the source node does not exist in ``turbo_nodes_map``, the edge is
        not generated. The edge's index is the number of edges created so far.
        """
        source = turbo_nodes_map.get(source_id)
        if source is None:
            return
        self.input_edges.append(
            TurboEdge(source=source, target=self, index=len(self.input_edges), edge_type=edge_type))

    @classmethod
    def parse_mtype(cls, mtype: str):
        """Parse a comma separated ``MType`` string into a dict.

        The first non-blank item is stored under ``'mtype'``. Items of the form
        ``bitfield=``, ``type=``, ``stamp=`` and ``mark=`` are stored under the
        matching key; ``bitfield`` is additionally parsed as a hexadecimal
        number into ``'bitfield_int'`` (0 plus an error message when it cannot
        be parsed). Everything after the first ``=`` is kept as the value.

        :param mtype: the raw ``MType`` text; anything that is not a string
            (for example ``None``) yields an empty dict.
        :return: the parsed fields, empty when there is nothing to parse.
        """
        if not isinstance(mtype, str):
            return {}
        # Split by ',' and drop blanks such as the "" produced by a trailing comma.
        items = [item for item in (raw.strip() for raw in mtype.split(',')) if item]
        if not items:
            return {}

        mtype_dict = {'mtype': items[0]}
        for item in items:
            key, separator, value = item.partition('=')
            if not separator:
                continue
            if key == 'bitfield':
                mtype_dict['bitfield'] = value
                try:
                    mtype_dict['bitfield_int'] = int(value, 16)
                except ValueError:
                    mtype_dict['bitfield_int'] = 0
                    print(f"\033[91mError: Failed to convert bitfield hexadecimal number {value}\033[0m")
            elif key in ('type', 'stamp', 'mark'):
                mtype_dict[key] = value
        return mtype_dict

    @classmethod
    def make_label_and_mtype(cls, cir_node: dict):
        """Build the display label of a circuit IR node.

        The label is the opcode, optionally followed by a detail in square
        brackets: the bytecode for ``JS_BYTECODE``, the parsed bitfield for
        ``ARG`` and ``CONSTANT``, otherwise the node's ``typedop`` or
        ``comment`` when present.

        :return: tuple ``(label, parsed MType dict)``.
        """
        op = cir_node.get('op')
        mtype = cls.parse_mtype(cir_node.get('MType'))
        if op == 'JS_BYTECODE':
            detail = cir_node.get('bytecode')
        elif op in ('ARG', 'CONSTANT'):
            detail = mtype.get('bitfield_int')
        elif 'typedop' in cir_node:
            detail = cir_node.get('typedop')
        elif 'comment' in cir_node:
            detail = cir_node.get('comment')
        else:
            return op, mtype
        return f'{op}[{detail}]', mtype

    @classmethod
    def is_control_node(cls, op: str):
        """Return whether ``op`` is one of the opcodes flagged as control."""
        return op in cls._CONTROL_OPCODES

    @classmethod
    def transform(cls, cir_node: dict) -> 'TurboNode':
        """Convert one circuit IR node dict into a ``TurboNode``.

        ``cir_node['in']`` is read positionally as state, depend, value,
        frame-state and root inputs. When fewer than five groups are present
        (or ``'in'`` is not a list at all) an error is printed and the missing
        groups stay empty. A ``None`` group, ``None`` outputs or a ``None``
        ``MType`` are treated as empty instead of raising.

        A node without any inputs and outputs is marked as not live.
        """
        op = cir_node.get('op')
        label, _ = cls.make_label_and_mtype(cir_node)
        turbo_node = cls(node_id=cir_node.get('id'), label=label, title=label, opcode=op,
                         control=cls.is_control_node(op))

        raw_inputs = cir_node.get('in')
        input_groups = raw_inputs if isinstance(raw_inputs, (list, tuple)) else []
        for attr_name, group in zip(cls._INPUT_ATTRS, input_groups):
            setattr(turbo_node, attr_name, group if group is not None else [])
        if len(input_groups) < len(cls._INPUT_ATTRS):
            print(f"\033[91mError: Node {turbo_node.id}: {turbo_node.opcode} Missing a type of input edge\033[0m")

        outputs = cir_node.get('out')
        turbo_node.outputs = outputs if outputs is not None else []

        if not any([turbo_node.state_inputs,
                    turbo_node.depend_inputs,
                    turbo_node.value_inputs,
                    turbo_node.frame_state_inputs,
                    turbo_node.root_inputs,
                    turbo_node.outputs]):
            turbo_node.live = False

        input_output_info = (
            f"{len(turbo_node.state_inputs)} state, "
            f"{len(turbo_node.depend_inputs)} depend, "
            f"{len(turbo_node.value_inputs)} value, "
            f"{len(turbo_node.frame_state_inputs)} frame, "
            f"{len(turbo_node.root_inputs)} root, "
            f"{len(turbo_node.outputs)} out"
        )
        mtype_text = cir_node.get('MType')
        turbo_node.opinfo = input_output_info + '\n' + ('' if mtype_text is None else str(mtype_text))

        return turbo_node

    def to_turbo(self) -> dict:
        """Return the exported fields of this node as a plain dict."""
        return {
            'id': self.id,
            'label': self.label,
            'title': self.title,
            'live': self.live,
            'properties': self.properties,
            'opcode': self.opcode,
            'control': self.control,
            'opinfo': self.opinfo,
        }


def get_phase(name: str, turbo_nodes_map: dict) -> dict:
    """Assemble one 'graph' phase from the given nodes.

    Every node of ``turbo_nodes_map`` is exported via ``to_turbo``; its
    ``input_edges`` are exported right after it, so edges appear grouped by
    target node in map order.
    """
    node_records = []
    edge_records = []
    for node in turbo_nodes_map.values():
        node_records.append(node.to_turbo())
        edge_records.extend(edge.to_turbo() for edge in node.input_edges)
    return {
        'name': name,
        'type': 'graph',
        'data': {
            'nodes': node_records,
            'edges': edge_records,
        },
    }


def cir_parser(circuit_ir: dict) -> dict:
    """Convert one parsed circuit IR phase into a Turbolizer graph phase.

    Nodes are created first and indexed by id (a later node with the same id
    replaces an earlier one); edges are created afterwards so that every
    source can be looked up in the complete map.
    """
    phase_name = circuit_ir.get('name')
    nodes_by_id = {
        converted.id: converted
        for converted in map(TurboNode.transform, circuit_ir.get('nodes'))
    }
    for converted in nodes_by_id.values():
        converted.create_input_edge(turbo_nodes_map=nodes_by_id)
    return get_phase(phase_name, turbo_nodes_map=nodes_by_id)


def read_log_file(file_path: str) -> list:
    """Extract the circuit IR phases contained in a log file.

    A phase is the text between a header such as
    "=== phase_name [function_name@abc_name]" and the next "=== End ===";
    only the nearest header before each end marker is used. Inside a phase,
    every line containing a "{...}" item is decoded as JSON and kept when it
    has the keys 'id', 'op', 'MType', 'in' and 'out'; other items are reported
    on stdout. Phases named "After range analysis" are parsed but not returned.

    :param file_path: path of the UTF-8 encoded log file.
    :return: list of dicts with keys 'abc_name', 'function_name', 'name' and 'nodes'.
    """
    with open(file_path, 'r', encoding='utf-8') as log_file:
        content = log_file.read()

    # Header line, then lazily any character that does not start another header,
    # up to the end marker.
    phase_pattern = r'(===.*\[.*\].*((?!===.*\[.*\].*)[\s\S])*?)=== End ==='
    # Captures phase_name, function_name and abc_name from the header line.
    phase_name_pattern = r'([^=]*)\[(.*)@(.*)\]'
    # First "{...}" item of a line.
    item_pattern = r'\{.*?\}'
    required_keys = {'id', 'op', 'MType', 'in', 'out'}

    phases = []
    for phase_text, _ in re.findall(phase_pattern, content):
        lines = phase_text.splitlines()
        name_match = re.search(phase_name_pattern, lines[0])
        if name_match:
            phase_name, function_name, abc_name = (part.strip() for part in name_match.groups())
        else:
            phase_name, function_name, abc_name = "unknown phase", "unknown function", "unknown abc"
        print(f"Parsing ==={abc_name}=== ==={function_name}=== ==={phase_name}===")

        node_list = []
        for line in lines:
            item_match = re.search(item_pattern, line)
            if not item_match:
                continue
            try:
                dict_item = json.loads(item_match.group(0))
            except JSONDecodeError:
                print(f"\033[91mError: Failed to decode JSON item: {item_match.group(0)}\033[0m")
                continue
            if required_keys.issubset(dict_item):
                node_list.append(dict_item)
            else:
                print(f"\033[91mError: JSON item {item_match.group(0)} is missing key in {required_keys}\033[0m")

        # The skip happens after scanning so the diagnostics above are still printed.
        if phase_name == "After range analysis":
            continue
        phases.append({
            'abc_name': abc_name,
            'function_name': function_name,
            'name': phase_name,
            'nodes': node_list,
        })
    return phases


def make_turbo_input(turbo_phases: list, function_name: str) -> dict:
    """Wrap converted phases into the top-level Turbolizer JSON document.

    The document always starts with an empty "disassembly" phase, followed by
    ``turbo_phases`` in their given order; all source-related sections are
    left empty.
    """
    function_info = {
        'sourceId': -1,
        'functionName': function_name,
        'sourceName': 'sourceName',
        'sourceText': '',
        'startPosition': 0,
        'endPosition': 0,
    }
    disassembly_phase = {
        "name": "disassembly",
        "type": "disassembly",
        "blockIdToOffset": {},
        "data": ""
    }
    return {
        'function': function_info,
        'phases': [disassembly_phase, *turbo_phases],
        'nodePositions': {},
        'sources': {},
        'inlinings': {},
        'bytecodeSources': {},
    }


def sanitize_filename(filename: str) -> str:
    """Remove characters that are illegal in file names.

    The function name and abc file name may contain such characters. After
    removing them, surrounding whitespace and then leading/trailing dots are
    stripped as well.
    """
    without_illegal = re.sub(r'[\\/*?:"<>|]+', '', filename)
    return without_illegal.strip().strip('.')


if __name__ == '__main__':
    # Command line: the log file to convert and an optional output directory.
    parser = argparse.ArgumentParser(
        description="Converts the circuit IR into a JSON file that can be used by the Turbolizer.")
    parser.add_argument("file_path", help="Path of input file")
    parser.add_argument("-d", "--directory", default="out", help="Output directory path")
    args = parser.parse_args()

    output_root = args.directory
    if output_root is not None and not os.path.exists(output_root):
        os.makedirs(output_root)

    print(f"===Start to convert {args.file_path}===")
    file_name = os.path.splitext(os.path.basename(args.file_path))[0]
    circuit_ir = read_log_file(args.file_path)
    print(f"===Parsing {args.file_path} succeeded.===")

    # Convert every phase and group the results by "<function>@<abc>".
    turbo_phases_dict = {}
    for turbo_phase_num, cir_phase in enumerate(circuit_ir):
        turbo_phase = cir_parser(cir_phase)
        function_abc_name = f"{sanitize_filename(cir_phase.get('function_name'))}@{sanitize_filename(cir_phase.get('abc_name'))}"
        turbo_phases_dict.setdefault(function_abc_name, []).append(turbo_phase)
        print(
            f"{turbo_phase_num}: Successfully converted phase ==={cir_phase.get('abc_name')}=== ==={cir_phase.get('function_name')}=== ==={cir_phase.get('name')}===")

    # Write one numbered JSON file per group into a directory named after the abc file.
    for json_file_num, (function_abc_name, turbo_phase_list) in enumerate(turbo_phases_dict.items()):
        output_dir = os.path.join(output_root, function_abc_name.split('@')[-1])
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)
        output_file_path = os.path.join(output_dir, f"{json_file_num}_{function_abc_name}.json")
        turbo_input = make_turbo_input(turbo_phase_list, function_abc_name)
        print(f"===Exporting {output_file_path}===")
        with open(output_file_path, 'w') as file:
            file.write(json.dumps(turbo_input))

    print("===Finished===")
